from Record.file_management import read_obj_dumps, load_from_pickle, strip_instance
from Vae.record_vae_and_masks_state import load_encodings
import numpy as np

def set_dict_min(enc_dict):
    """Cap every entry of ``enc_dict`` at -0.01, modifying it in place.

    Entries strictly greater than -0.01 are overwritten with -0.01. Entries
    already at or below -0.01 are kept, and NaN entries are left untouched
    because they compare False. Nothing is returned.

    ``compute_encoder_mean_variance`` uses this to push per-dimension minima
    below zero. Despite the parameter name, ``enc_dict`` must be a mutable
    numpy array (it is indexed with a boolean mask). For an integer array the
    assigned value is truncated to 0.
    """
    cap = -0.01
    exceeds_cap = enc_dict > cap
    enc_dict[exceeds_cap] = cap

def set_dict_max(enc_dict):
    """Raise every entry of ``enc_dict`` to at least 0.01, modifying it in place.

    Entries strictly below 0.01 are overwritten with 0.01. Entries already at
    or above 0.01 are kept, and NaN entries are left untouched because they
    compare False. Nothing is returned.

    ``compute_encoder_mean_variance`` uses this to push per-dimension maxima
    above zero. Despite the parameter name, ``enc_dict`` must be a mutable
    numpy array (it is indexed with a boolean mask). For an integer array the
    assigned value is truncated to 0.
    """
    floor = 0.01
    below_floor = enc_dict < floor
    enc_dict[below_floor] = floor


def compute_encoder_mean_variance(args):
    """Compute per-object-type value ranges of encodings and of their step deltas.

    Despite its name, this computes no mean or variance. It returns
    elementwise (min, max) bounds, widened in place by ``set_dict_min`` /
    ``set_dict_max`` so that each lower bound is at most -0.01 and each upper
    bound is at least 0.01 (for float arrays; NaN excepted).

    Args:
        args: configuration namespace. Reads:
            ``args.train.load_encodings``: passed to ``load_encodings``;
                when falsy, nothing is computed.
            ``args.train.load_rollouts``: passed to ``read_obj_dumps``.
            ``args.train.num_frames``: the trailing ``num_frames + 1``
                encodings are kept (all of them when it is negative).

    Returns:
        ``(enc_range, enc_dyn_range)``, two dicts keyed by
        ``strip_instance(name)``:
            ``enc_range[k]`` is ``(min, max)`` over the encodings themselves.
            ``enc_dyn_range[k]`` is ``(min, max)`` over
                ``previous - current`` encoding differences.
        Both dicts are empty when fewer than two non-empty encodings are
        available. ``(None, None)`` is returned when
        ``args.train.load_encodings`` is falsy.

    Raises:
        KeyError: if an object name in one record is missing from the
            preceding record.
    """
    if not args.train.load_encodings:
        return None, None

    # Slice start for keeping the trailing ``num_frames + 1`` records; it is
    # clamped to 0 (keep everything) when num_frames is negative.
    start = min(0, -args.train.num_frames - 1)
    data = read_obj_dumps(args.train.load_rollouts, i=-1, rng=args.train.num_frames,
                          filename='object_dumps.txt')
    encodings = load_encodings(args.train.load_encodings)[start:]
    # Drop empty records. The trailing slice is not re-applied afterwards
    # because filtering can only shorten the list.
    encodings = [d for d in encodings if d != {}]  # shouldn't have to do this
    if not encodings:
        return dict(), dict()

    all_encodings, all_encodings_diff = dict(), dict()
    last_encoding = encodings[0]
    # NOTE(review): data[1:] and encodings[1:] are assumed to be aligned
    # frame-for-frame; zip silently truncates to the shorter one - confirm.
    for next_factored_state, encoding in zip(data[1:], encodings[1:]):
        # Transitions into a "Done" state are skipped (presumably an episode
        # boundary, where the delta is not a real step).
        if not next_factored_state["Done"]:
            for n in encoding.keys():
                n_type = strip_instance(n)
                if n_type not in all_encodings:
                    all_encodings[n_type], all_encodings_diff[n_type] = list(), list()
                # Sanitized in place, so the value is also clean when this
                # record becomes ``last_encoding`` on the next step.
                encoding[n] = np.nan_to_num(encoding[n], copy=True, nan=0.0, posinf=None, neginf=None)
                # ``last_encoding`` is not always sanitized (the first record,
                # or one that followed a Done step). Clean it here too, or NaNs
                # propagate through np.min/np.max into the ranges below.
                prev = np.nan_to_num(last_encoding[n], copy=True, nan=0.0, posinf=None, neginf=None)
                all_encodings[n_type].append(encoding[n])
                # NOTE(review): this is previous - current (the negated step);
                # confirm the sign is intended for enc_dyn_range.
                all_encodings_diff[n_type].append(prev.astype(float) - encoding[n].astype(float))
        last_encoding = encoding

    enc_range, enc_dyn_range = dict(), dict()
    for n in all_encodings:
        enc_range[n] = (np.min(all_encodings[n], axis=0), np.max(all_encodings[n], axis=0))
        enc_dyn_range[n] = (np.min(all_encodings_diff[n], axis=0),
                            np.max(all_encodings_diff[n], axis=0))
        # Widen bounds in place: lower bounds to <= -0.01, upper to >= 0.01.
        set_dict_min(enc_range[n][0])
        set_dict_max(enc_range[n][1])
        set_dict_min(enc_dyn_range[n][0])
        set_dict_max(enc_dyn_range[n][1])
    return enc_range, enc_dyn_range